<!DOCTYPE html>
<html lang="zh-CN">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1"><meta name="robots" content="noodp"/><title>TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型 | Yasin&#39;s Blog</title><meta name="twitter:card" content="summary_large_image"/>
<meta name="twitter:image" content=""/>
<meta name="twitter:title" content="TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型"/>
<meta name="twitter:description" content=""/><meta name="twitter:creator" content="@wangyuexin8"/><meta name="Description" content="KEEP KWARKING"><meta property="og:title" content="TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型" />
<meta property="og:description" content="以MNIST的sequential模型为base-line，通过读取自己的数据，训练模型并存储模型，最后实现给图识物的应用。 自制数据集，解决" />
<meta property="og:type" content="article" />
<meta property="og:url" content="https://blog.aimoon.top/selfdatatrain/" /><meta property="og:image" content="https://blog.aimoon.top/images/favicon.svg"/><meta property="article:section" content="posts" />
<meta property="article:published_time" content="2020-06-04T23:34:53&#43;08:00" />
<meta property="article:modified_time" content="2021-03-29T11:34:14&#43;08:00" /><meta property="og:site_name" content="Yasin&#39;s Blog" />

<meta name="application-name" content="YASIN">
<meta name="apple-mobile-web-app-title" content="YASIN"><meta name="theme-color" content="#ffffff"><meta name="msapplication-TileColor" content="#da532c"><link rel="icon" href="/images/favicon.svg" type="image/svg+xml"><link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png"><link rel="mask-icon" href="/safari-pinned-tab.svg" color="#5bbad5"><link rel="manifest" href="/site.webmanifest"><link rel="canonical" href="https://blog.aimoon.top/selfdatatrain/" /><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/normalize.css@8.0.1/normalize.min.css"><link rel="stylesheet" href="/css/style.min.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/animate.css@3.7.2/animate.min.css"><script type="application/ld+json">
    {
        "@context": "https://schema.org",
        "@type": "BlogPosting",
        "headline": "TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型",
        "inLanguage": "zh-CN",
        "mainEntityOfPage": {
            "@type": "WebPage",
            "@id": "https:\/\/blog.aimoon.top\/selfdatatrain\/"
        },"image": ["https:\/\/blog.aimoon.top\/images\/cover.png"],"genre": "posts","keywords": "datasets, mnist","wordCount":  2844 ,
        "url": "https:\/\/blog.aimoon.top\/selfdatatrain\/","datePublished": "2020-06-04T23:34:53+08:00","dateModified": "2021-03-29T11:34:14+08:00",
        "publisher": {
            "@type": "Person",
            "name": "Wang Yuexin", "image": [
            {
            "@type": "ImageObject",
            "url": "https:\/\/blog.aimoon.top\/images\/avatars.png"
            }
            ]},"author": {
                "@type": "Person",
                "name": "Wang Yuexin"
            },"description": ""
    }
    </script><script type="application/ld+json">
    {
        "@context": "https://schema.org",
        "@type": "BreadcrumbList",
        "itemListElement": [{
            "@type": "ListItem",
            "position": 1,
            "name": "主页",
            "item": "https:\/\/blog.aimoon.top"
        },{
            "@type": "ListItem",
            "position": 2,
            "name": "TF2.1学习笔记",
            "item": "https://blog.aimoon.top/categories/tf2.1%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/"
        },{
                "@type": "ListItem",
                "position": 3,
                "name": "TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型"
            }]
    }
</script></head>
    <body data-header-desktop="auto" data-header-mobile="auto"><script>
        (function () {
            // Inline, non-deferred script placed right after <body>: it runs before the rest of the
            // page is parsed, so the dark theme attribute is set as early as possible.
            // A truthy 'theme' value in localStorage (the visitor's saved choice) wins; otherwise the
            // site default decides. The 'light' literals below look like a template-rendered default.
            var saved = null;
            try {
                // Reading window.localStorage throws a SecurityError when site storage is blocked;
                // treat that as "no saved preference" instead of aborting with an uncaught error.
                saved = window.localStorage ? window.localStorage.getItem('theme') : null;
            } catch (e) {
                saved = null;
            }
            var useDark = saved
                ? saved === 'dark'
                : ('light' === 'auto'
                    ? window.matchMedia('(prefers-color-scheme: dark)').matches
                    : 'light' === 'dark');
            if (useDark) {
                document.body.setAttribute('theme', 'dark');
            }
        })();
    </script>

        <div id="mask"></div><div class="wrapper"><header>
    <div class="desktop header" id="header-desktop">
        <div class="header-wrapper">
            <div class="header-title">
                <a href="/" title="Yasin&#39;s Blog" class="header-logo logo-svg">Yasin&#39;s Blog</a>
            </div>
            <div class="menu">
                <nav>
                    <h2 class="display-hidden">主导航</h2>
                    <ul class="menu-inner"><li>
                            <a class="menu-item" href="/posts/"> 目录 </a>
                        </li><li>
                            <a class="menu-item" href="/tags/"> 标签 </a>
                        </li><li>
                            <a class="menu-item" href="/categories/"> 归档 </a>
                        </li><li>
                            <a class="menu-item" href="/comments/"> 留言 </a>
                        </li><li>
                            <a class="menu-item" href="https://aimoon.top" rel="noopener noreferrer" target="_blank"> 主页 </a>
                        </li></ul>
                </nav><span class="menu-item delimiter"></span><span class="menu-item search" id="search-desktop">
                        <input type="text" placeholder="search……" id="search-input-desktop" aria-label="搜索">
                        <a href="javascript:void(0);" class="search-button search-toggle" id="search-toggle-desktop" title="搜索">
                            <span class="svg-icon icon-search"></span>
                        </a>
                        <a href="javascript:void(0);" class="search-button search-clear" id="search-clear-desktop" title="清空">
                            <span class="svg-icon icon-cancel"></span>
                        </a>
                        <span class="search-button search-loading" id="search-loading-desktop">
                            <span class="svg-icon icon-loading"></span>
                        </span>
                    </span><a href="javascript:void(0);" class="menu-item theme-switch" title="切换主题">
                <span class="svg-icon icon-moon"></span>
                </a>
            </div>
        </div>
    </div><div class="mobile header" id="header-mobile">
        <div class="header-container">
            <div class="header-wrapper">
                <div class="header-title">
                    <a href="/" title="Yasin&#39;s Blog" class="header-logo">Yasin&#39;s Blog</a>
                </div>
                <div class="menu-toggle" id="menu-toggle-mobile">
                    <span></span><span></span><span></span>
                </div>
            </div>
            <div class="menu" id="menu-mobile"><div class="search-wrapper">
                        <div class="search mobile" id="search-mobile">
                            <input type="text" placeholder="search……" id="search-input-mobile" aria-label="搜索">
                            <a href="javascript:void(0);" class="search-button search-toggle" id="search-toggle-mobile" title="搜索">
                                <span class="svg-icon icon-search"></span>
                            </a>
                            <a href="javascript:void(0);" class="search-button search-clear" id="search-clear-mobile" title="清空">
                                <span class="svg-icon icon-cancel"></span>
                            </a>
                            <span class="search-button search-loading" id="search-loading-mobile">
                                <span class="svg-icon icon-loading"></span>
                            </span>
                        </div>
                        <a href="javascript:void(0);" class="search-cancel" id="search-cancel-mobile">
                            取消
                        </a>
                    </div><nav>
                    <h2 class="display-hidden">主导航</h2>
                    <ul><li>
                            <a class="menu-item" href="/posts/" title="">目录</a>
                        </li><li>
                            <a class="menu-item" href="/tags/" title="">标签</a>
                        </li><li>
                            <a class="menu-item" href="/categories/" title="">归档</a>
                        </li><li>
                            <a class="menu-item" href="/comments/" title="">留言</a>
                        </li><li>
                            <a class="menu-item" href="https://aimoon.top" title="" rel="noopener noreferrer" target="_blank">主页</a>
                        </li></ul>
                </nav>
                <a href="javascript:void(0);" class="menu-item theme-switch" title="切换主题">
                    <span class="svg-icon icon-moon"></span>
                </a></div>
        </div>
    </div><div class="search-dropdown desktop">
    <div id="search-dropdown-desktop"></div>
</div>
<div class="search-dropdown mobile">
    <div id="search-dropdown-mobile"></div>
</div></header><main class="main">
<div class="container content-article page-toc theme-classic"><div class="toc" id="toc-auto">
            <div class="toc-title">目录</div>
            <div class="toc-content" id="toc-content-auto"></div>
        </div>
    

    
    
    <article>
    

        <header class="header-post">

            

            
            <div class="post-title">

                    <div class="post-all-meta">
                        <nav class="breadcrumbs">
    <ol>
        <li><a href="/">主页 </a></li><li><a href="/categories/tf2.1%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/">TF2.1学习笔记 </a></li><li>TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型</li>
    </ol>
</nav>
                        <h1 class="single-title flipInX">TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型</h1><div class="post-meta summary-post-meta"><span class="post-category meta-item">
                                <a href="/categories/tf2.1%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/"><span class="svg-icon icon-folder"></span>TF2.1学习笔记</a>
                            </span><span class="post-meta-date meta-item">
                                <span class="svg-icon icon-clock"></span><time class="timeago" datetime="2020-06-04">2020-06-04</time>
                            </span><span class="post-meta-words meta-item">
                                <span class="svg-icon icon-pencil"></span>约 2844 字
                            </span>
                            <span class="post-meta-reading meta-item">
                                <span class="svg-icon icon-stopwatch"></span>预计阅读 6 分钟
                            </span>
                        </div>

                    </div>

                </div>

                </header>

        <div class="article-post toc-start">

            <div class="content-block content-block-first content-block-position">

                <div class="post single"><div class="image-theme-classic">
                        <img src="https://img-blog.csdnimg.cn/2020060500240012.png" alt="" style="width: 100%">
                    </div><div class="details toc" id="toc-static"  data-kept="">
                        <div class="details-summary toc-title">
                            <span>目录</span>
                        </div>
                        <div class="details-content toc-content" id="toc-content-static"><nav id="TableOfContents">
  <ul>
    <li><a href="#自制数据集解决本领域应用">自制数据集，解决本领域应用</a>
      <ul>
        <li><a href="#观察数据结构">观察数据结构</a></li>
        <li><a href="#def-generateds图片路径标签文件">def generateds(图片路径,标签文件)：</a></li>
      </ul>
    </li>
    <li><a href="#数据增强扩充数据集">数据增强，扩充数据集</a></li>
    <li><a href="#断点续训存取模型">断点续训，存取模型</a>
      <ul>
        <li><a href="#读取保存模型">读取保存模型</a></li>
        <li><a href="#保存模型">保存模型</a></li>
      </ul>
    </li>
    <li><a href="#参数提取把参数存入文本">参数提取，把参数存入文本</a></li>
    <li><a href="#accloss可视化查看训练效果">acc/loss可视化，查看训练效果</a></li>
    <li><a href="#应用程序给图识物">应用程序，给图识物</a></li>
  </ul>
</nav></div>
                    </div><p>以MNIST的<a href="https://blog.aimoon.top/2020/05/tf_keras_mnist/#%E4%BD%BF%E7%94%A8sequential%E5%AE%9E%E7%8E%B0%E6%89%8B%E5%86%99%E5%AD%97%E4%BD%93%E8%AF%86%E5%88%AB" target="_blank" rel="noopener noreferrer">sequential模型</a>为base-line，通过读取自己的数据，训练模型并存储模型，最后实现给图识物的应用。</p>
<h2 id="自制数据集解决本领域应用" class="headerLink"><a href="#%e8%87%aa%e5%88%b6%e6%95%b0%e6%8d%ae%e9%9b%86%e8%a7%a3%e5%86%b3%e6%9c%ac%e9%a2%86%e5%9f%9f%e5%ba%94%e7%94%a8" class="header-mark"></a>自制数据集，解决本领域应用</h2><h3 id="观察数据结构" class="headerLink"><a href="#%e8%a7%82%e5%af%9f%e6%95%b0%e6%8d%ae%e7%bb%93%e6%9e%84" class="header-mark"></a>观察数据结构</h3><p>给x_train、y_train、x_test、y_test赋值</p>
<p>




<img loading="lazy" decoding="async"
         class="render-image"
         src="https://img-blog.csdnimg.cn/20200604000506672.png"
         alt="数据集结构示意图"
         title="20200604000506672.png"
    /></p>
<p>




<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604212835629.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604212835629.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure></p>
<h3 id="def-generateds图片路径标签文件" class="headerLink"><a href="#def-generateds%e5%9b%be%e7%89%87%e8%b7%af%e5%be%84%e6%a0%87%e7%ad%be%e6%96%87%e4%bb%b6" class="header-mark"></a>def generateds(图片路径,标签文件)：</h3><div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">generateds</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">txt</span><span class="p">):</span>
    <span class="n">f</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="n">txt</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span>  <span class="c1"># 以只读形式打开txt文件</span>
    <span class="n">contents</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">readlines</span><span class="p">()</span>  <span class="c1"># 读取文件中所有行</span>
    <span class="n">f</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>  <span class="c1"># 关闭txt文件</span>
    <span class="n">x</span><span class="p">,</span> <span class="n">y_</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>  <span class="c1"># 建立空列表</span>
    <span class="k">for</span> <span class="n">content</span> <span class="ow">in</span> <span class="n">contents</span><span class="p">:</span>  <span class="c1"># 逐行取出</span>
        <span class="n">value</span> <span class="o">=</span> <span class="n">content</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>  <span class="c1"># 以空格分开，图片路径为value[0] , 标签为value[1] , 存入列表</span>
        <span class="n">img_path</span> <span class="o">=</span> <span class="n">path</span> <span class="o">+</span> <span class="n">value</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>  <span class="c1"># 拼出图片路径和文件名</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>  <span class="c1"># 读入图片</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;L&#39;</span><span class="p">))</span>  <span class="c1"># 图片变为8位宽灰度值的np.array格式</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">img</span> <span class="o">/</span> <span class="mf">255.</span>  <span class="c1"># 数据归一化 （实现预处理）</span>
        <span class="n">x</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>  <span class="c1"># 归一化后的数据，贴到列表x</span>
        <span class="n">y_</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">value</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># 标签贴到列表y_</span>
        <span class="k">print</span><span class="p">(</span><span class="s1">&#39;loading : &#39;</span> <span class="o">+</span> <span class="n">content</span><span class="p">)</span>  <span class="c1"># 打印状态提示</span>

    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># 变为np.array格式</span>
    <span class="n">y_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">y_</span><span class="p">)</span>  <span class="c1"># 变为np.array格式</span>
    <span class="n">y_</span> <span class="o">=</span> <span class="n">y_</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>  <span class="c1"># 变为64位整型</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y_</span>  <span class="c1"># 返回输入特征x，返回标签y_</span>
</code></pre></td></tr></table>
</div>
</div><h2 id="数据增强扩充数据集" class="headerLink"><a href="#%e6%95%b0%e6%8d%ae%e5%a2%9e%e5%bc%ba%e6%89%a9%e5%85%85%e6%95%b0%e6%8d%ae%e9%9b%86" class="header-mark"></a>数据增强，扩充数据集</h2><p>数据增强（增大数据量），可以简单的扩展数据集，对图像的数据增强就是对图像的简单形变。</p>
<p>tensorflow2中的数据增强函数</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">image_gen_train</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">preprocessing</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">ImageDataGenerator</span><span class="p">(</span>
	<span class="n">rescale</span> <span class="o">=</span> <span class="err">所有数据将乘以该数值</span>
	<span class="n">rotation_range</span> <span class="o">=</span> <span class="err">随机旋转角度数范围</span>
	<span class="n">width_shift_range</span> <span class="o">=</span> <span class="err">随机宽度偏移量</span>
	<span class="n">height_shift_range</span> <span class="o">=</span> <span class="err">随机高度偏移量</span>
	<span class="err">水平翻转：</span><span class="n">horizontal_flip</span> <span class="o">=</span> <span class="err">是否随机水平翻转</span>
	<span class="err">随机缩放：</span><span class="n">zoom_range</span> <span class="o">=</span> <span class="err">随机缩放的范围</span> <span class="p">[</span><span class="mi">1</span><span class="o">-</span><span class="n">n</span><span class="err">，</span><span class="mi">1</span><span class="o">+</span><span class="n">n</span><span class="p">]</span> <span class="p">)</span>
<span class="n">image_gen_train</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span>

<span class="c1">### 例 ###</span>
<span class="n">image_gen_train</span> <span class="o">=</span> <span class="n">ImageDataGenerator</span><span class="p">(</span>
	<span class="n">rescale</span><span class="o">=</span><span class="mf">1.</span> <span class="o">/</span> <span class="mf">1.</span><span class="p">,</span> <span class="c1"># 如为图像，分母为255时，可归至0～1</span>
	<span class="n">rotation_range</span><span class="o">=</span><span class="mi">45</span><span class="p">,</span> <span class="c1"># 随机45度旋转</span>
	<span class="n">width_shift_range</span><span class="o">=.</span><span class="mi">15</span><span class="p">,</span> <span class="c1"># 宽度偏移</span>
	<span class="n">height_shift_range</span><span class="o">=.</span><span class="mi">15</span><span class="p">,</span> <span class="c1"># 高度偏移</span>
	<span class="n">horizontal_flip</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="c1"># 水平翻转</span>
	<span class="n">zoom_range</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>  <span class="c1"># 将图像随机缩放阈量50％</span>
<span class="n">image_gen_train</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p>其中image_gen_train.fit(x_train)中的fit需要一个四维数组
即：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p>(60000, 28, 28) $\Rightarrow$ (60000, 28, 28, 1)
将60000张28行28列的数据转换成60000张28行28列单通道的数据集，其中“1”表示通道数（灰度图只有一个通道）。
model.fit()同步更新为model.fit(image_gen_train.flow(x_train, y_train,batch_size=32), ……)</p>
<center>model.fit(x_train, y_train,batch_size=32, ……)</center>
<p>$$\Downarrow$$</p>
<center>model.fit(image_gen_train.flow(x_train, y_train,batch_size=32), ……)</center>
<p>加入数据增强后的完整代码及训练结果：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>
<span class="kn">from</span> <span class="nn">tensorflow.keras.preprocessing.image</span> <span class="kn">import</span> <span class="n">ImageDataGenerator</span>

<span class="n">mnist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span>
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">/</span> <span class="mf">255.0</span>
<span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>  <span class="c1"># 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28, 28, 1)</span>

<span class="n">image_gen_train</span> <span class="o">=</span> <span class="n">ImageDataGenerator</span><span class="p">(</span>
    <span class="n">rescale</span><span class="o">=</span><span class="mf">1.</span> <span class="o">/</span> <span class="mf">1.</span><span class="p">,</span>  <span class="c1"># 如为图像，分母为255时，可归至0～1</span>
    <span class="n">rotation_range</span><span class="o">=</span><span class="mi">45</span><span class="p">,</span>  <span class="c1"># 随机45度旋转</span>
    <span class="n">width_shift_range</span><span class="o">=.</span><span class="mi">15</span><span class="p">,</span>  <span class="c1"># 宽度偏移</span>
    <span class="n">height_shift_range</span><span class="o">=.</span><span class="mi">15</span><span class="p">,</span>  <span class="c1"># 高度偏移</span>
    <span class="n">horizontal_flip</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>  <span class="c1"># 水平翻转</span>
    <span class="n">zoom_range</span><span class="o">=</span><span class="mf">0.5</span>  <span class="c1"># 将图像随机缩放阈量50％</span>
<span class="p">)</span>
<span class="n">image_gen_train</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">)</span>
<span class="p">])</span>

<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">&#39;adam&#39;</span><span class="p">,</span>
              <span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">])</span>

<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">image_gen_train</span><span class="o">.</span><span class="n">flow</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span>
          <span class="n">validation_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>

</code></pre></td></tr></table>
</div>
</div><p>




<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604231346403.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604231346403.png"
            alt="加入数据增强后的训练输出"
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure></p>
<ul>
<li>随着模型迭代轮数的增加，模型的准确率不断提高</li>
<li>数据增强在小数据量上可以提高模型的泛化能力</li>
</ul>
<h2 id="断点续训存取模型" class="headerLink"><a href="#%e6%96%ad%e7%82%b9%e7%bb%ad%e8%ae%ad%e5%ad%98%e5%8f%96%e6%a8%a1%e5%9e%8b" class="header-mark"></a>断点续训，存取模型</h2><h3 id="读取保存模型" class="headerLink"><a href="#%e8%af%bb%e5%8f%96%e4%bf%9d%e5%ad%98%e6%a8%a1%e5%9e%8b" class="header-mark"></a>读取保存模型</h3><p>load_weights(路径文件名)</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">checkpoint_save_path</span> <span class="o">=</span> <span class="s2">&#34;./checkpoint/fashion.ckpt&#34;</span>	<span class="c1">#先定义出存放模型的路径和文件名，“.ckpt”文件在生成时会同步生成索引表</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">checkpoint_save_path</span> <span class="o">+</span> <span class="s1">&#39;.index&#39;</span><span class="p">):</span>		<span class="c1">#判断是否有索引表，就可以知道是否保存过模型，如果有索引表，就调用load_weights()读取模型</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">&#39;-------------load the model-----------------&#39;</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">checkpoint_save_path</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><h3 id="保存模型" class="headerLink"><a href="#%e4%bf%9d%e5%ad%98%e6%a8%a1%e5%9e%8b" class="header-mark"></a>保存模型</h3><p>使用tensorflow给出的回调函数直接保存训练的参数：</p>
<p>tf.keras.callbacks.ModelCheckpoint(filepath=路径文件名,save_weights_only=True/False,save_best_only=True/False)</p>
<p>history = model.fit(..., callbacks=[cp_callback])</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">cp_callback</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="n">checkpoint_save_path</span><span class="p">,</span>		<span class="c1"># 文件存储路径</span>
                                                 <span class="n">save_weights_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>			<span class="c1"># 是否只保留模型参数</span>
                                                 <span class="n">save_best_only</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>				<span class="c1"># 是否只保留模型最优参数</span>

<span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> 						<span class="c1"># 加入callbacks选项，记录到history中</span>
					<span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">validation_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                    <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">cp_callback</span><span class="p">])</span>
</code></pre></td></tr></table>
</div>
</div><p>加入断点续训的完整代码：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">os</span>		<span class="c1"># 引入os模块，（文件处理）</span>

<span class="n">mnist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span>
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">/</span> <span class="mf">255.0</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">)</span>
<span class="p">])</span>

<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">&#39;adam&#39;</span><span class="p">,</span>
              <span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">])</span>

<span class="n">checkpoint_save_path</span> <span class="o">=</span> <span class="s2">&#34;./checkpoint/fashion.ckpt&#34;</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">checkpoint_save_path</span> <span class="o">+</span> <span class="s1">&#39;.index&#39;</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">&#39;-------------load the model-----------------&#39;</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">checkpoint_save_path</span><span class="p">)</span>

<span class="n">cp_callback</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="n">checkpoint_save_path</span><span class="p">,</span>
                                                 <span class="n">save_weights_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                                                 <span class="n">save_best_only</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">validation_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                    <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">cp_callback</span><span class="p">])</span>
<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>

</code></pre></td></tr></table>
</div>
</div><p>第一次运行：</p>
<p>




<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604233937210.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604233937210.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure>





<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604234150899.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604234150899.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure></p>
<p>第二次运行：</p>
<p>




<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604234607210.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604234607210.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure>





<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200604234949472.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200604234949472.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure></p>
<p>加载了第一次保存的参数，准确率在第一次的基础上提高</p>
<h2 id="参数提取把参数存入文本" class="headerLink"><a href="#%e5%8f%82%e6%95%b0%e6%8f%90%e5%8f%96%e6%8a%8a%e5%8f%82%e6%95%b0%e5%ad%98%e5%85%a5%e6%96%87%e6%9c%ac" class="header-mark"></a>参数提取，把参数存入文本</h2><ul>
<li>提取可训练参数
model.trainable_variables 返回模型中可训练的参数</li>
<li>设置print输出格式
np.set_printoptions(threshold=超过多少省略显示)</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span> <span class="c1"># np.inf表示无限大</span>
</code></pre></td></tr></table>
</div>
</div><ul>
<li>将可训练参数存入文本</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span>
<span class="nb">file</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="s1">&#39;./weights.txt&#39;</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">:</span>
	<span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
	<span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
	<span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">file</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
</code></pre></td></tr></table>
</div>
</div><p>在断点续训的基础上加入参数提取</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>

<span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span>

<span class="n">mnist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span>
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">/</span> <span class="mf">255.0</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">)</span>
<span class="p">])</span>

<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">&#39;adam&#39;</span><span class="p">,</span>
              <span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">])</span>

<span class="n">checkpoint_save_path</span> <span class="o">=</span> <span class="s2">&#34;./checkpoint/fashion.ckpt&#34;</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">checkpoint_save_path</span> <span class="o">+</span> <span class="s1">&#39;.index&#39;</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">&#39;-------------load the model-----------------&#39;</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">checkpoint_save_path</span><span class="p">)</span>

<span class="n">cp_callback</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="n">checkpoint_save_path</span><span class="p">,</span>
                                                 <span class="n">save_weights_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                                                 <span class="n">save_best_only</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">validation_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                    <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">cp_callback</span><span class="p">])</span>
<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span>
<span class="nb">file</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="s1">&#39;./weights.txt&#39;</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">:</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">file</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>

</code></pre></td></tr></table>
</div>
</div><p>运行得到weights.txt文件</p>
<p>




<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/20200605000228419.png" title=" " >
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/20200605000228419.png"
            alt=" "
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure></p>
<h2 id="accloss可视化查看训练效果" class="headerLink"><a href="#accloss%e5%8f%af%e8%a7%86%e5%8c%96%e6%9f%a5%e7%9c%8b%e8%ae%ad%e7%bb%83%e6%95%88%e6%9e%9c" class="header-mark"></a>acc/loss可视化，查看训练效果</h2><ul>
<li>将训练过程可视化出来
在history中同步记录了训练集loss、测试集loss、训练集准确率和测试集准确率</li>
<li>history=model.fit(训练集数据, 训练集标签, batch_size=, epochs=,validation_split=用作测试数据的比例,validation_data=测试集,validation_freq=测试频率)</li>
<li>history
训练集loss： loss
测试集loss： val_loss
训练集准确率： sparse_categorical_accuracy
测试集准确率： val_sparse_categorical_accuracy</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">acc</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">]</span>
<span class="n">val_acc</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;val_sparse_categorical_accuracy&#39;</span><span class="p">]</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;loss&#39;</span><span class="p">]</span>
<span class="n">val_loss</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;val_loss&#39;</span><span class="p">]</span>	<span class="c1">#使用history.history[]提取</span>
</code></pre></td></tr></table>
</div>
</div><p>加入绘制图像的代码：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span><span class="lnt">50
</span><span class="lnt">51
</span><span class="lnt">52
</span><span class="lnt">53
</span><span class="lnt">54
</span><span class="lnt">55
</span><span class="lnt">56
</span><span class="lnt">57
</span><span class="lnt">58
</span><span class="lnt">59
</span><span class="lnt">60
</span><span class="lnt">61
</span><span class="lnt">62
</span><span class="lnt">63
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>		<span class="c1"># 导入绘图模块</span>

<span class="n">np</span><span class="o">.</span><span class="n">set_printoptions</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span>

<span class="n">mnist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span>
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">/</span> <span class="mf">255.0</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">)</span>
<span class="p">])</span>

<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">&#39;adam&#39;</span><span class="p">,</span>
              <span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">])</span>

<span class="n">checkpoint_save_path</span> <span class="o">=</span> <span class="s2">&#34;./checkpoint/fashion.ckpt&#34;</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">checkpoint_save_path</span> <span class="o">+</span> <span class="s1">&#39;.index&#39;</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s1">&#39;-------------load the model-----------------&#39;</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">checkpoint_save_path</span><span class="p">)</span>

<span class="n">cp_callback</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="n">checkpoint_save_path</span><span class="p">,</span>
                                                 <span class="n">save_weights_only</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                                                 <span class="n">save_best_only</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">),</span> <span class="n">validation_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                    <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">cp_callback</span><span class="p">])</span>
<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span>
<span class="nb">file</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="s1">&#39;./weights.txt&#39;</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">:</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="nb">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">file</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>

<span class="c1">###############################################    show   ###############################################</span>

<span class="c1"># 显示训练集和验证集的acc和loss曲线</span>
<span class="n">acc</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;sparse_categorical_accuracy&#39;</span><span class="p">]</span>
<span class="n">val_acc</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;val_sparse_categorical_accuracy&#39;</span><span class="p">]</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;loss&#39;</span><span class="p">]</span>
<span class="n">val_loss</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">&#39;val_loss&#39;</span><span class="p">]</span>

<span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">acc</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Training Accuracy&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">val_acc</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Validation Accuracy&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s1">&#39;Training and Validation Accuracy&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>

<span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Training Loss&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">val_loss</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Validation Loss&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s1">&#39;Training and Validation Loss&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>

</code></pre></td></tr></table>
</div>
</div><p>可视化结果：</p>
<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/2020060500162050.png">
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/2020060500162050.png"
            alt="训练集与验证集的准确率（acc）和损失（loss）曲线"
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure>
<h2 id="应用程序给图识物" class="headerLink"><a href="#%e5%ba%94%e7%94%a8%e7%a8%8b%e5%ba%8f%e7%bb%99%e5%9b%be%e8%af%86%e7%89%a9" class="header-mark"></a>应用程序，给图识物</h2><p>前面已经将模型训练好了，下面编写一个应用程序，实现给图识物。</p>
<figure class="render-image"><a target="_blank" href="https://img-blog.csdnimg.cn/2020060500240012.png">
        <img loading="lazy" decoding="async"
             class="render-image"
             src="https://img-blog.csdnimg.cn/2020060500240012.png"
            alt="给图识物应用程序示意图"
        />
    </a><figcaption class="image-caption"> </figcaption>
</figure>
<ul>
<li><code>model.predict(输入特征, batch_size=整数)</code>：返回前向传播的计算结果</li>
</ul>
<p>前向传播执行应用：</p>
<ol>
<li>复现模型（前向传播）</li>
<li>加载参数：<code>model.load_weights(model_save_path)</code></li>
<li>预测结果：<code>result = model.predict(x_predict)</code></li>
</ol>
<p>源码：</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="kn">as</span> <span class="nn">tf</span>

<span class="n">model_save_path</span> <span class="o">=</span> <span class="s1">&#39;./checkpoint/mnist.ckpt&#39;</span><!-- NOTE(review): the training listing above saves its checkpoint to ./checkpoint/fashion.ckpt; this path must name the same checkpoint or model.load_weights raises an error. Confirm which name is intended. -->

<span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span>						<span class="c1"># 复现网络</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">)])</span>

<span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">model_save_path</span><span class="p">)</span> 						<span class="c1"># 加载参数</span>

<span class="n">preNum</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">input</span><span class="p">(</span><span class="s2">&#34;input the number of test pictures:&#34;</span><span class="p">))</span>	<span class="c1"># 准备预测多少个数</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">preNum</span><span class="p">):</span>										<span class="c1"># 读入待识别的图片</span>
    <span class="n">image_path</span> <span class="o">=</span> <span class="nb">input</span><span class="p">(</span><span class="s2">&#34;the path of test picture:&#34;</span><span class="p">)</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">),</span> <span class="n">Image</span><span class="o">.</span><span class="n">LANCZOS</span><span class="p">)</span>				<span class="c1"># 转换成（28，28）的类型，与训练数据类型匹配</span><!-- Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same resampling filter under its current name. -->
    <span class="n">img_arr</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;L&#39;</span><span class="p">))</span>					<span class="c1"># 转换成灰度图</span>

    <span class="n">img_arr</span> <span class="o">=</span> <span class="mi">255</span> <span class="o">-</span> <span class="n">img_arr</span>									<span class="c1"># 将“白底黑字”反转成“黑底白字”</span>

    <span class="c1">#####or#####</span>
    <span class="c1"># for i in range(28):                            # 转换成高对比度的图，过滤噪声</span>
    <span class="c1">#     for j in range(28):</span>
    <span class="c1">#         if img_arr[i][j] &lt; 200:</span>
    <span class="c1">#             img_arr[i][j] = 255</span>
    <span class="c1">#         else:</span>
    <span class="c1">#             img_arr[i][j] = 0</span>


    <span class="n">img_arr</span> <span class="o">=</span> <span class="n">img_arr</span> <span class="o">/</span> <span class="mf">255.0</span>								<span class="c1"># 归一化</span>
    <span class="k">print</span><span class="p">(</span><span class="s2">&#34;img_arr:&#34;</span><span class="p">,</span><span class="n">img_arr</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">x_predict</span> <span class="o">=</span> <span class="n">img_arr</span><span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>					<span class="c1"># 由于是按每个batch送入网络，故添加一个维度</span>
    <span class="k">print</span><span class="p">(</span><span class="s2">&#34;x_predict:&#34;</span><span class="p">,</span><span class="n">x_predict</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_predict</span><span class="p">)</span>						<span class="c1"># 预测结果</span>

    <span class="n">pred</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
    <span class="n">tf</span><span class="o">.</span><span class="k">print</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>

</code></pre></td></tr></table>
</div>
</div></div><footer>
                        <div class="post">


<div class="post-share"><div class="share-link">
        <a class="share-icon share-twitter" href="javascript:void(0);" title="分享到 Twitter" data-sharer="twitter" data-url="https://blog.aimoon.top/selfdatatrain/" data-title="TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型" data-via="wangyuexin8" data-hashtags="datasets,mnist"><span class="svg-social-icon icon-twitter"></span></a>
    </div><div class="share-link">
        <a class="share-icon share-facebook" href="javascript:void(0);" title="分享到 Facebook" data-sharer="facebook" data-url="https://blog.aimoon.top/selfdatatrain/" data-hashtag="datasets"><span class="svg-social-icon icon-facebook"></span></a>
    </div><div class="share-link">
        <a class="share-icon share-whatsapp" href="javascript:void(0);" title="分享到 WhatsApp" data-sharer="whatsapp" data-url="https://blog.aimoon.top/selfdatatrain/" data-title="TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型" data-web><span class="svg-social-icon icon-whatsapp"></span></a>
    </div><div class="share-link">
        <a class="share-icon share-blogger" href="javascript:void(0);" title="分享到 Blogger" data-sharer="blogger" data-url="https://blog.aimoon.top/selfdatatrain/" data-title="TensorFlow2.1入门学习笔记(11)——自制数据集，并记录训练模型" data-description=""><span class="svg-social-icon icon-blogger"></span></a>
    </div></div>

<div class="footer-post-author">
    <div class="author-avatar"><a href="https://aimoon.top" target="_blank"><img alt="Wang Yuexin 的个人主页" src="https://blog.aimoon.top/images/avatars.png"></a></div>
    <div class="author-info">
        <div class="name"><a href="https://aimoon.top" target="_blank">Wang Yuexin</a></div>
        <div class="number-posts">Undergraduate Student of Artificial Intelligence 😜</div>
    </div>
</div><div class="post-tags"><a href="/tags/datasets/" class="tag">datasets</a><a href="/tags/mnist/" class="tag">mnist</a></div></div>
                </footer></div>
        <div id="toc-final"></div>
        </div>

    
    </article>
    <section class="page single comments content-block-position">
        <h2 class="display-hidden">评论</h2><div id="comments"><div id="disqus_thread" class="comment" style="padding-top: 1.5rem"></div>
            <noscript>
                Please enable JavaScript to view the comments powered by <a href="https://disqus.com/?ref_noscript">Disqus</a>.
            </noscript></div></section></div>

</main><footer class="footer">
        <div class="footer-container"><div class="footer-line"><div><span id="timeDate">正在烧脑计算建站时间...</span><span id="times"></span><script>
// Site-uptime counter: days / hours / minutes / seconds since 2020-05-20 00:00 in the visitor's local time.
// The elapsed time is re-read from the system clock on every tick, so the display stays correct even when
// the browser delays or throttles the timer (e.g. in a background tab) instead of drifting behind.
var now = new Date();
function createtime() {
    var grt = new Date(2020, 4, 20, 0, 0, 0); // month index is 0-based: 4 = May
    now = new Date();
    var elapsed = Math.floor((now - grt) / 1000); // whole seconds since launch
    var dnum = Math.floor(elapsed / 86400);
    var hnum = Math.floor((elapsed % 86400) / 3600);
    var mnum = Math.floor((elapsed % 3600) / 60);
    var snum = elapsed % 60; // integer 0-59, so the seconds field can never read "60"
    // Zero-pad to two digits, matching the original display format.
    function pad2(n) { return n < 10 ? "0" + n : String(n); }
    document.getElementById("timeDate").innerHTML = "&nbsp;" + dnum + "&nbsp;天";
    document.getElementById("times").innerHTML = pad2(hnum) + "&nbsp;小时&nbsp;" + pad2(mnum) + "&nbsp;分&nbsp;" + pad2(snum) + "&nbsp;秒";
}
createtime(); // fill the placeholder immediately instead of after the first interval
setInterval(createtime, 250); // pass the function itself; a code string is eval'd and blocked by strict CSP
</script></div></div><div class="footer-line"><i class="svg-icon icon-copyright"></i><span>2020 - 2021</span><span class="author">&nbsp;<a href="https://aimoon.top" target="_blank">Yasin</a></span>&nbsp;|&nbsp;<span class="license"><a rel="license external nofollow noopener noreferrer" href="https://creativecommons.org/licenses/by-nc/4.0/" target="_blank">CC BY-NC 4.0</a></span><span class="icp-splitter">&nbsp;|&nbsp;</span><br class="icp-br"/>
                    <span class="icp"><a href="https://blog.pangao.vip/icp/xmoon.info">🧑ICP证000000号</a></span></div>
        </div>
    </footer></div>

        <aside id="fixed-buttons"><a href="#" id="back-to-top" class="fixed-button" title="回到顶部">
                <i class="svg-icon icon-arrow-up"></i>
            </a><a href="#" id="view-comments" class="fixed-button" title="查看评论">
                <i class="svg-icon icon-comments-fixed"></i>
            </a>
        </aside><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/contrib/copy-tex.min.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/cookieconsent@3.1.1/build/cookieconsent.min.css"><script src="https://yasin5.disqus.com/embed.js" defer></script><script src="https://cdn.jsdelivr.net/npm/smooth-scroll@16.1.3/dist/smooth-scroll.min.js"></script><script src="https://cdn.jsdelivr.net/npm/autocomplete.js@0.37.1/dist/autocomplete.min.js"></script><script src="https://cdn.jsdelivr.net/npm/lunr@2.3.8/lunr.min.js"></script><script src="/lib/lunr/lunr.stemmer.support.min.js"></script><script src="/lib/lunr/lunr.zh.min.js"></script><script src="https://cdn.jsdelivr.net/npm/twemoji@13.0.0/dist/twemoji.min.js"></script><script src="https://cdn.jsdelivr.net/npm/clipboard@2.0.6/dist/clipboard.min.js"></script><script src="https://cdn.jsdelivr.net/npm/sharer.js@0.4.0/sharer.min.js"></script><script src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.js"></script><script src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/contrib/auto-render.min.js"></script><script src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/contrib/copy-tex.min.js"></script><script src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/contrib/mhchem.min.js"></script><script src="https://cdn.jsdelivr.net/npm/cookieconsent@3.1.1/build/cookieconsent.min.js"></script><script>window.config={"code":{"copyTitle":"复制到剪贴板","maxShownLines":10},"comment":{},"cookieconsent":{"content":{"dismiss":"同意","link":"了解更多","message":"本网站使用 Cookies 来改善您的浏览体验."},"enable":true,"palette":{"button":{"background":"#f0f0f0"},"popup":{"background":"#1aa3ff"}},"theme":"edgeless"},"math":{"delimiters":[{"display":true,"left":"$$","right":"$$"},{"display":true,"left":"\\[","right":"\\]"},{"display":false,"left":"$","right":"$"},{"display":false,"left":"\\(","right":"\\)"}],"strict":false},"search":{"highlightTag":"em","lunrIndexURL":"/index.json","lunrLanguageCode":"zh","lunrSegmentitURL":"/lib/lunr/lunr.segmentit.js","maxResultLength":10,"noResultsFound":"没有找到结果","snippetLength":30,"type":"lunr"},"twemoji":true};</script><script src="/js/theme.min.js"></script><script>
                // Google Analytics loader (Universal Analytics, analytics.js) with IP anonymisation.
                // NOTE(review): Universal Analytics properties (UA-*) stopped processing new data on
                // 1 July 2023, so these hits are no longer recorded. Migrate to a GA4 gtag.js snippet
                // with a G-* measurement ID, or remove this block.
                (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
                (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
                m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
                })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

	        ga('create', 'UA-167439955-2', 'auto');
	        ga('set', 'anonymizeIp', true);
	        ga('send', 'pageview');
	    </script></body>
</html>
